﻿#pragma kernel ClearChunks
#pragma kernel ClearGrid
#pragma kernel InsertChunks
#pragma kernel SortParticles
#pragma kernel FindPopulatedLevels

#include "MathUtils.cginc"
#include "Bounds.cginc"
#include "GridUtils.cginc"
#include "FluidChunkDefs.cginc"
#include "SolverParameters.cginc"

// particle data (read-only inputs, indexed by the particle index taken from particleIndices):
StructuredBuffer<int> particleIndices;          // particle index for each entry; kernels read element (firstParticle + thread index).
StructuredBuffer<float4> positions;
StructuredBuffer<float4> velocities;
StructuredBuffer<float4> angularVelocities;     // the low 16 bits of w's bit pattern are unpacked by SortParticles.
StructuredBuffer<float4> principalRadii;        // particles with w <= 0.5 are treated as inactive and skipped.
StructuredBuffer<float4> fluidMaterial;         // x is used as the particle size (grid level selection, chunk span).
StructuredBuffer<float4> fluidData;             // SortParticles stores 1 / x alongside the sorted position.
StructuredBuffer<quaternion> orientations;
StructuredBuffer<float4> colors;
StructuredBuffer<float4> normals;               // only referenced by commented-out code in this file.

// particle data rearranged in grid order, written by the SortParticles kernel:
RWStructuredBuffer<float4> sortedPositions;   
RWStructuredBuffer<float4> sortedPrincipalRadii;  
RWStructuredBuffer<float4> sortedVelocities;  
RWStructuredBuffer<quaternion> sortedOrientations;   
RWStructuredBuffer<float4> sortedColors;   

// particle grid data:
StructuredBuffer<aabb> solverBounds;            // element 0's min_ is used as the origin for cell coordinates.
StructuredBuffer<int> gridHashToSortedIndex;    // maps a cell's grid hash to its cell index.
RWStructuredBuffer<uint> cellOffsets;    // start of each cell in the sorted item array.
RWStructuredBuffer<uint> cellCounts;     // number of items in each cell.
RWStructuredBuffer<uint> offsetInCell;   // for each item, its offset within its cell.

// voxel data:
RWStructuredBuffer<int3> chunkCoords; // for each chunk, spatial coordinates.
RWStructuredBuffer<keyvalue> hashtable; // size: maxChunks entries.

// element 4 is atomically incremented by AllocateChunk to hand out chunk handles.
RWStructuredBuffer<uint> dispatchBuffer; 

uint firstParticle;     // offset into particleIndices of the first particle to process.
uint particleCount;     // number of particles to process; threads beyond this count exit early.
float isosurface;       // not referenced by the kernels in this file.
float smoothing;        // not referenced by the kernels in this file.

// Attempts to claim a hash table slot for the chunk at the given coordinates.
// Starting at the slot given by hash(VoxelID(coords)), slots are probed linearly
// (wrapping around at maxChunks) until one of the following happens:
//  - an empty slot is claimed: a fresh handle is taken from dispatchBuffer[4], the chunk's
//    coordinates are stored under that handle, and the handle is returned.
//  - a slot already holding this chunk's key is found: INVALID is returned.
//  - maxChunks slots have been probed without success: INVALID is returned.
uint AllocateChunk(uint3 coords)
{
    const uint chunkKey = VoxelID(coords);
    uint probe = hash(chunkKey);

    [allow_uav_condition]
    for (uint attempt = 0; attempt < maxChunks; ++attempt) // probe every slot once, at most.
    {
        // atomically write our key into the slot, but only if the slot is currently empty:
        uint existing;
        InterlockedCompareExchange(hashtable[probe].key, INVALID, chunkKey, existing);

        // the slot was empty and is now ours, allocate a new chunk:
        if (existing == INVALID)
        {
            uint handle;
            InterlockedAdd(dispatchBuffer[4], 1, handle);

            hashtable[probe].handle = handle;
            chunkCoords[handle] = coords;
            return handle;
        }

        // the chunk already exists in the table, nothing to allocate:
        if (existing == chunkKey)
            return INVALID;

        // slot taken by a different chunk, move on to the next one:
        probe = (probe + 1) % maxChunks;
    }

    // table exhausted, there's no room left for this chunk.
    return INVALID;
}

// Resets every entry of the chunk hash table to the empty state.
// One thread per table entry; threads with an index of maxChunks or above do nothing.
[numthreads(128, 1, 1)]
void ClearChunks (uint3 id : SV_DispatchThreadID) 
{
    uint i = id.x;

    if (i >= maxChunks)
        return;

    // AllocateChunk() claims a slot by compare-exchanging its key against INVALID, so empty
    // slots must hold exactly that sentinel: use the named constant instead of a literal
    // to guarantee clearing and allocation always agree.
    keyvalue k;
    k.key = INVALID;
    k.handle = INVALID;
    hashtable[i] = k;
}

// Resets the particle grid: one thread per cell, threads at or past maxCells do nothing.
// The thread handling cell 0 additionally zeroes all levelPopulation entries
// (indices 0 to GRID_LEVELS, both included).
[numthreads(128, 1, 1)]
void ClearGrid (uint3 id : SV_DispatchThreadID) 
{
    const uint cell = id.x;

    if (cell < maxCells)
    {
        // a single thread takes care of resetting the per-level population counters:
        if (cell == 0)
        {
            for (int level = 0; level <= GRID_LEVELS; ++level)
                levelPopulation[level] = 0;
        }

        // an empty cell contains no items and has no valid start offset.
        cellCounts[cell] = 0;
        cellOffsets[cell] = INVALID;
    }
}

// Per-particle kernel. For every active particle it:
//  1. reserves a place for the particle in its grid cell (stored in offsetInCell),
//  2. increases the population counter of the particle's grid level,
//  3. requests allocation of every chunk touched by the particle's bounds.
[numthreads(128, 1, 1)]
void InsertChunks (uint3 id : SV_DispatchThreadID) 
{
    const uint threadIndex = id.x;
    if (threadIndex >= particleCount)
        return;

    const int p = particleIndices[firstParticle + threadIndex];

    // inactive particles are skipped entirely:
    if (principalRadii[p].w <= 0.5f)
        return;

    const float particleSize = fluidMaterial[p].x;
    const float3 center = positions[p].xyz;

    // find the grid cell containing the particle, at the level matching its size:
    const int gridLevel = GridLevelForSize(particleSize);
    const float gridCellSize = CellSizeOfLevel(gridLevel);
    int4 gridCoord = int4(floor((center - solverBounds[0].min_.xyz) / gridCellSize), gridLevel);

    // 2D solvers keep all particles in the z = 0 layer of cells.
    if (mode == 1)
        gridCoord.z = 0;

    // cell hashes are morton-sorted during collision detection, so cells obtained here are sorted too
    // as long as particle cell coordinates are reasonably similar (they won't match exactly,
    // since renderable positions are used here instead of positions).
    const uint sortedCell = gridHashToSortedIndex[GridHash(gridCoord)];

    // reserve a place in the cell: the cell's previous count becomes this particle's offset within it.
    InterlockedAdd(cellCounts[sortedCell], 1, offsetInCell[p]);

    // one more particle lives in this level:
    InterlockedAdd(levelPopulation[1 + gridLevel], 1);

    // disabled: in 3D, only let particles near the surface spawn chunks.
    //if (mode == 0 && dot(normals[p],normals[p]) < 0.0001f) 
        //return;

    // boundary voxels (at a chunk's 0 X/Y/Z) can't be triangulated, so pad the particle's bounds by one voxel.
    const float reach = particleSize + voxelSize;

    // range of chunks overlapped by the padded bounds:
    const float chunkExtent = chunkResolution * voxelSize;
    uint3 firstChunk = floor((center - reach - chunkGridOrigin) / chunkExtent);
    uint3 lastChunk = floor((center + reach - chunkGridOrigin) / chunkExtent);

    // 2D solvers only use the z = 0 layer of chunks.
    if (mode == 1)
    {
        lastChunk.z = 0;
        firstChunk.z = 0;
    }

    // request every chunk in the range; chunks that already exist are left untouched.
    for (uint cx = firstChunk.x; cx <= lastChunk.x; ++cx)
        for (uint cy = firstChunk.y; cy <= lastChunk.y; ++cy)
            for (uint cz = firstChunk.z; cz <= lastChunk.z; ++cz)
                AllocateChunk(uint3(cx, cy, cz));
}

// Per-particle kernel. Copies the data of every active particle to its final place in the
// sorted buffers: the start offset of its grid cell plus the particle's offset within that cell.
[numthreads(128, 1, 1)]
void SortParticles (uint3 id : SV_DispatchThreadID) 
{
    const uint threadIndex = id.x;
    if (threadIndex >= particleCount)
        return;

    const int p = particleIndices[firstParticle + threadIndex];

    // inactive particles are not written to the sorted buffers:
    if (principalRadii[p].w <= 0.5f)
        return;

    // find the grid cell containing the particle, at the level matching its size:
    const int gridLevel = GridLevelForSize(fluidMaterial[p].x);
    const float gridCellSize = CellSizeOfLevel(gridLevel);
    int4 gridCoord = int4(floor((positions[p].xyz - solverBounds[0].min_.xyz) / gridCellSize), gridLevel);

    // 2D solvers keep all particles in the z = 0 layer of cells.
    if (mode == 1)
        gridCoord.z = 0;

    const uint sortedCell = gridHashToSortedIndex[GridHash(gridCoord)];

    // destination index: start of the cell plus this particle's offset inside it.
    const uint dst = cellOffsets[sortedCell] + offsetInCell[p];

    // position, with the reciprocal of fluidData.x stored in w:
    sortedPositions[dst] = float4(positions[p].xyz, 1 / fluidData[p].x);

    // principal radii relative to their x component, scaled by the particle size:
    sortedPrincipalRadii[dst] = fluidMaterial[p].x * (principalRadii[p] / principalRadii[p].x);

    // velocity, with the low 16 bits of angularVelocities.w mapped to the 0-1 range stored in w:
    const uint packedLow = asuint(angularVelocities[p].w) & 0x0000ffff;
    sortedVelocities[dst] = float4(velocities[p].xyz, packedLow / 65535.0);

    // orientation and color are copied verbatim:
    sortedOrientations[dst] = orientations[p];
    sortedColors[dst] = colors[p];
}



